% By Martin M. Andreasen, April 2019
% This function computes the conditional moments up to k_period into the
% future for the control variable in the g-function.
% The computations are derived by using the Perturbation On Perturbation (POP) method.
% IMPORTANT: 1) This is for a "level" approximation under perfect foresight.
%            2) If firstOn = 1, the first time period in px, pxx, pxxx, p4x is
%            the current time period, i.e. we reproduce gx, gxx, gxxx, g4x.
%
function [px,pxx,pxxx,p4x] = CondMoments_4th_levelCE(gx,gxx,gxxx,g4x,hx,hxx,hxxx,h4x,k_period,firstOn,order_app)
% CondMoments_4th_levelCE  Chain-rule derivatives of the control path p_t(x) = p_{t-1}(h(x)).
%
% Starting from p_0 = g, every future period is obtained from the previous
% one by composing with the state transition h. Its first- to fourth-order
% derivatives with respect to the states are computed by the multivariate
% chain rule (Faa di Bruno formula).
%
% Inputs
%   gx, gxx, gxxx, g4x : derivatives of g, of size ny x nx, ny x nx x nx,
%                        ny x nx x nx x nx and ny x nx x nx x nx x nx.
%   hx, hxx, hxxx, h4x : derivatives of h, of size nx x nx, nx x nx x nx,
%                        nx x nx x nx x nx and nx x nx x nx x nx x nx
%                        (first index = element of h, the rest = states).
%   k_period           : number of future periods to compute.
%   firstOn            : if 1, period 1 of every output holds gx, gxx, gxxx,
%                        g4x and k_period+1 periods are returned.
%   order_app          : approximation order (1-4). For the periods produced
%                        by the recursion, derivatives above this order are 0.
%
% Outputs (K = number of periods returned)
%   px : ny x K x nx,             pxx : ny x K x nx x nx,
%   pxxx : ny x K x nx x nx x nx, p4x : ny x K x nx x nx x nx x nx.
%   If k_period == 1 and firstOn == 0, the period dimension is dropped.
%
% Assumption: all derivative tensors are symmetric in their state indices.
% The kron-based contractions rely on this. So does the symmetric fill of
% pxx, pxxx and p4x, where only sorted index tuples a1 <= a2 <= ... are
% evaluated and copied to all their permutations.

% Allocate the outputs (one extra period when the current one is included)
ny = size(gx,1);
nx = size(hx,1);
if firstOn == 1
    k_period = k_period + 1;
end
px   = zeros(ny,k_period,nx);
pxx  = zeros(ny,k_period,nx,nx);
pxxx = zeros(ny,k_period,nx,nx,nx);
p4x  = zeros(ny,k_period,nx,nx,nx,nx);

% Optionally store the current period, i.e. the derivatives of g itself
if firstOn == 1
    px(:,1,:)        = gx;
    pxx(:,1,:,:)     = gxx;
    pxxx(:,1,:,:,:)  = gxxx;
    p4x(:,1,:,:,:,:) = g4x;
    startIndex       = 2;
else
    startIndex       = 1;
end

for iy = 1:ny
    % Derivatives of the previous period's function for control iy (p_0 = g)
    rx   = reshape(gx(iy,:),1,nx);
    rxx  = reshape(gxx(iy,:,:),nx,nx);
    rxxx = reshape(gxxx(iy,:,:,:),nx,nx,nx);
    r4x  = reshape(g4x(iy,:,:,:,:),nx,nx,nx,nx);

    for t = startIndex:k_period
        % Row-vector forms used in the kron contractions. They do not change
        % within period t, so build them once instead of per index tuple.
        rxxRow  = reshape(rxx,1,nx^2);
        rxxxRow = reshape(rxxx,1,nx^3);
        r4xRow  = reshape(r4x,1,nx^4);

        % ************** First-order effects *****************
        px(iy,t,:) = rx*hx;

        % ************** Second-order effects ****************
        if order_app > 1
            for a1 = 1:nx
                for a2 = a1:nx
                    val = hx(:,a1)'*rxx*hx(:,a2) + rx*hxx(:,a1,a2);
                    pxx(iy,t,a1,a2) = val;
                    pxx(iy,t,a2,a1) = val;
                end
            end
        end

        % ************** Third-order effects *****************
        if order_app > 2
            for a1 = 1:nx
                for a2 = a1:nx
                    for a3 = a2:nx
                        val = rxxxRow*kron(hx(:,a1),kron(hx(:,a2),hx(:,a3))) ...
                            + hx(:,a1)'*rxx*hxx(:,a2,a3) ...
                            + hxx(:,a1,a3).'*rxx*hx(:,a2) ...
                            + hxx(:,a1,a2).'*rxx*hx(:,a3) ...
                            + rx*hxxx(:,a1,a2,a3);
                        % Copy to every permutation of (a1,a2,a3) by symmetry
                        idx = perms([a1 a2 a3]);
                        for r = 1:size(idx,1)
                            pxxx(iy,t,idx(r,1),idx(r,2),idx(r,3)) = val;
                        end
                    end
                end
            end
        end

        % ************** Fourth-order effects ****************
        % The 15 terms are the 15 set partitions of {1,2,3,4} (Faa di Bruno).
        if order_app > 3
            for a1 = 1:nx
                h1 = hx(:,a1);
                for a2 = a1:nx
                    h2    = hx(:,a2);
                    hxx12 = hxx(:,a1,a2);
                    for a3 = a2:nx
                        h3     = hx(:,a3);
                        hxx13  = hxx(:,a1,a3);
                        hxx23  = hxx(:,a2,a3);
                        hxxx123 = hxxx(:,a1,a2,a3);
                        for a4 = a3:nx
                            h4      = hx(:,a4);
                            hxx14   = hxx(:,a1,a4);
                            hxx24   = hxx(:,a2,a4);
                            hxx34   = hxx(:,a3,a4);
                            hxxx124 = hxxx(:,a1,a2,a4);
                            hxxx134 = hxxx(:,a1,a3,a4);
                            hxxx234 = hxxx(:,a2,a3,a4);
                            val = r4xRow*kron(h1,kron(h2,kron(h3,h4))) ...            % {1}{2}{3}{4}
                                + rxxxRow*(kron(h1,kron(h2,hxx34)) ...                % {34}{1}{2}
                                         + kron(h1,kron(hxx24,h3)) ...                % {24}{1}{3}
                                         + kron(hxx14,kron(h2,h3))) ...               % {14}{2}{3}
                                + rxxxRow*kron(h1,kron(hxx23,h4)) ...                 % {23}{1}{4}
                                + rxxRow*(kron(h1,hxxx234) + kron(hxx14,hxx23)) ...   % {234}{1}, {14}{23}
                                + rxxxRow*kron(hxx13,kron(h2,h4)) ...                 % {13}{2}{4}
                                + rxxRow*(kron(hxx13,hxx24) + kron(hxxx134,h2)) ...   % {13}{24}, {134}{2}
                                + rxxxRow*kron(hxx12,kron(h3,h4)) ...                 % {12}{3}{4}
                                + rxxRow*(kron(hxx12,hxx34) + kron(hxxx124,h3) ... % {12}{34}, {124}{3}
                                         + kron(hxxx123,h4)) ...                      % {123}{4}
                                + rx*h4x(:,a1,a2,a3,a4);                              % {1234}
                            % Copy to every permutation of (a1,a2,a3,a4) by symmetry
                            idx = perms([a1 a2 a3 a4]);
                            for r = 1:size(idx,1)
                                p4x(iy,t,idx(r,1),idx(r,2),idx(r,3),idx(r,4)) = val;
                            end
                        end
                    end
                end
            end
        end

        % The derivatives just computed become the "previous period" ones
        rx   = reshape(px(iy,t,:),1,nx);
        rxx  = reshape(pxx(iy,t,:,:),nx,nx);
        rxxx = reshape(pxxx(iy,t,:,:,:),nx,nx,nx);
        r4x  = reshape(p4x(iy,t,:,:,:,:),nx,nx,nx,nx);
    end
end

% Drop the singleton period dimension when a single future period is requested
if k_period == 1 && firstOn == 0
    px   = reshape(px,ny,nx);
    pxx  = reshape(pxx,ny,nx,nx);
    pxxx = reshape(pxxx,ny,nx,nx,nx);
    p4x  = reshape(p4x,ny,nx,nx,nx,nx);
end